import safety_gym  # not referenced by name below; presumably imported so the 'Safexp-*' ids passed to gym.make are registered - confirm
import numpy as np
import tensorflow as tf
# NOTE(review): the next import rebinds `tf` to the TF1-compat API, shadowing the line above;
# in the visible code `tf` is only used for disable_v2_behavior().
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import time
import numpy as np  # NOTE(review): duplicate of the numpy import above
import torch
import torch.nn as nn
import gym
import sys
import os
# Relative paths used later (the "<exp_name>/validation_<ckpt>/" output directory and
# "../PPO-POINT-TRAIN/") are resolved against this working directory.
# NOTE(review): absolute, machine-specific path.
os.chdir('/home/user/safety-starter-agents/safe_rl/PPO_Lagrangian_PyTorch/data/PPO-POINT-VAL/')
# Make the training directory importable; presumably `core` (and `utils`) are picked up from here - confirm.
sys.path.append('/home/user/safety-starter-agents/safe_rl/PPO_Lagrangian_PyTorch/data/PPO-POINT-TRAIN/')
import core
# from utils.logx import EpochLogger
from utils.mpi_tools import mpi_fork, proc_id, num_procs, mpi_sum
# NOTE(review): per the PyTorch docs, anomaly detection is a debugging aid that slows autograd;
# confirm it is wanted for validation runs.
torch.autograd.set_detect_anomaly(True)
# NOTE(review): time, sysv_ipc, copy, multiprocessing and pandas are not referenced in the visible code.
import sysv_ipc
import torch.nn.functional as F
import copy
import multiprocessing
import pandas as pd


class Safety_NN(nn.Module):
    """Three-layer fully connected network.

    Maps an ``n_state``-dimensional input to ``n_class`` outputs through two
    hidden layers of width 128, each followed by a ReLU. No activation is
    applied to the final layer, so the outputs are raw scores.
    """

    def __init__(self, n_state, n_class):
        super().__init__()
        width = 128
        # Registration order (layer1, layer2, layer3) is kept so parameter
        # names and initialisation order stay the same.
        self.layer1 = nn.Linear(n_state, width)
        self.layer2 = nn.Linear(width, width)
        self.layer3 = nn.Linear(width, n_class)

    def forward(self, x):
        """Return the raw ``n_class`` scores for input ``x``."""
        hidden = F.relu(self.layer2(F.relu(self.layer1(x))))
        return self.layer3(hidden)


# Device used by the script: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Size constants. NOTE(review): apart from `device`, none of these are read by
# the visible code in this file; presumably kept for other scripts - confirm.
n_sample_of_action = 252
n_class, n_rank, n_action = 2, 2, 2
n_NN = 1
n_observations = 60


def validation(thread_order, agent, n_epoch, local_steps_per_epoch, render, max_ep_len, val_data_path,
               env_name='Safexp-PointGoal1-v0', env=None, cv_window=60):
    """Roll out ``agent`` in a gym environment and append validation statistics to files.

    Output files live under ``val_data_path``, and ``thread_order`` is appended to each
    file stem. Each time an evaluation episode reaches ``max_ep_len`` agent steps,
    one line is appended to each of:

    * ``POF_COST<k>.txt``: the sum of ``info['cost']`` since the last write.
    * ``POF_PF<k>.txt``: the sum of rewards since the last write.
    * ``POF_CV<k>.txt``: the constraint-violation ratio described below.

    ``POF_ACT<k>.txt`` receives ``str(action)`` on every step.

    When a step reports a non-zero cost, the environment is reset but the evaluation
    episode keeps running: ``ep_len`` is not reset, and ``done`` is cleared for that step.
    If the environment itself reports ``done`` before ``max_ep_len`` steps, the
    environment is reset and nothing is written. In that case all accumulators,
    including the CV counters, carry over to the next write.

    The CV ratio works as follows. The steps since the last write are split into runs,
    each ending with a violating step. Each run contributes at most ``cv_window`` steps
    to the numerator; the trailing, violation-free steps contribute none. The
    denominator is the total number of steps since the last write, plus 1e-11 so it
    is never zero.

    Args:
        thread_order: Suffix used in the output file names.
        agent: Policy object. ``agent.step(obs_tensor, 0)`` must return a 4-tuple
            whose first item is the action passed to ``env.step``.
        n_epoch: Number of epochs to run.
        local_steps_per_epoch: Agent steps per epoch.
        render: If true, call ``env.render()`` when ``proc_id() == 0``.
        max_ep_len: Agent steps per evaluation episode (one line per file per episode).
        val_data_path: Output location. It is concatenated directly with the file
            names, so it must end with a path separator.
        env_name: Gym id used to build the environment when ``env`` is None.
        env: Optional pre-built environment using the 4-tuple ``step`` API
            ``(obs, reward, done, info)``.
        cv_window: Maximum number of steps a single violation run adds to the CV numerator.
    """
    if env is None:
        env = gym.make(env_name)

    suffix = str(thread_order) + ".txt"
    base_cost_path = val_data_path + "POF_COST" + suffix
    base_reward_path = val_data_path + "POF_PF" + suffix
    base_cv_path = val_data_path + "POF_CV" + suffix
    base_act_path = val_data_path + "POF_ACT" + suffix

    # `with` guarantees the files are closed even if the rollout raises.
    with open(base_cost_path, "a") as costfile, \
            open(base_reward_path, "a") as pffile, \
            open(base_cv_path, "a") as cvfile, \
            open(base_act_path, "a") as acfile:
        state = env.reset()
        ep_len = 0
        acc_cost = 0
        acc_reward = 0
        # CV bookkeeping. It is initialised once, not per epoch, so the CV statistics
        # cover the same steps as acc_cost/acc_reward even when
        # local_steps_per_epoch != max_ep_len.
        steps_since_violation = 0  # steps since the last violation or counter reset
        cv_counter = 0             # numerator of the CV ratio
        cv_denom_counter = 0       # denominator of the CV ratio
        for epoch in range(n_epoch):
            for t in range(local_steps_per_epoch):
                if render and proc_id() == 0:
                    env.render()
                # NOTE(review): the meaning of the second argument (0) is defined by agent.step - confirm.
                action, _, _, _ = agent.step(torch.as_tensor(state, dtype=torch.float32), 0)
                acfile.write(str(action) + "\n")

                state, reward, done, info = env.step(action)
                cost = info.get('cost', 0)
                acc_cost += cost
                acc_reward += reward
                ep_len += 1

                if cost != 0:
                    # This run of steps ends with the violating step (inclusive).
                    run_length = steps_since_violation + 1
                    cv_counter += min(run_length, cv_window)
                    cv_denom_counter += run_length
                    steps_since_violation = 0
                    # Restart the env but keep the evaluation episode running; clearing
                    # `done` means only ep_len can end the episode on this step.
                    state = env.reset()
                    done = False
                else:
                    steps_since_violation += 1

                if done or ep_len == max_ep_len:
                    print("RESET at epoch:%d, local_epoch:%d" % (epoch, t + 1))
                    if ep_len == max_ep_len:
                        costfile.write(str(acc_cost) + "\n")
                        pffile.write(str(acc_reward) + "\n")
                        # Trailing violation-free steps count toward the denominator only.
                        cv_denom_counter += steps_since_violation
                        cvfile.write(str(float(cv_counter / (cv_denom_counter + 0.00000000001))) + "\n")
                        acc_cost = 0
                        acc_reward = 0
                        steps_since_violation = 0
                        cv_counter = 0
                        cv_denom_counter = 0
                        # Push the finished episode's lines to the OS; this replaces the
                        # old close/re-open cycle.
                        for fh in (costfile, pffile, cvfile, acfile):
                            fh.flush()
                    state = env.reset()
                    ep_len = 0


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Safexp-PointGoal1-v0')
    parser.add_argument('--hid', type=int, default=256)
    parser.add_argument('--l', type=int, default=2)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=10000)
    parser.add_argument('--local_steps_per_epoch', type=int, default=1000)
    parser.add_argument('--len', type=int, default=1000)
    parser.add_argument('--exp_name', type=str, default='error')
    parser.add_argument('--checkpoint', type=str, default='-1')
    parser.add_argument('--render', action='store_true')
    args = parser.parse_args()

    # Experiments this script accepts. Names are split on '_' below into
    # <val>_<ckpt>_<agent>_<task>[_<agent_sub>].
    val_exp_set = {"baseval_pretraining_ppo_point_barrier": -1,
               "baseval_pretraining_ppo_point_nolag": -1,
               "baseval_pretraining_ppo_point_lag0": -1,
               "baseval_pretraining_ppo_point_lag0.1": -1,
               "baseval_pretraining_ppo_point_lag0.2": -1,
               "baseval_pretraining_ppo_point_lag0.3": -1,
               "baseval_pretraining_ppo_point_lag0.5": -1,
               "baseval_pretraining_ppo_point_lag1": -1,
               "baseval_pretraining_ppo_point_lag1.5": -1,
               "baseval_pretraining_ppo_point_lag2": -1}

    if args.exp_name not in val_exp_set:
        print(f"'{args.exp_name}'is invalid validation experiment. Please check 'val_exp_set'.")
        # Stop here. Continuing would either crash on the name split below or
        # silently run an experiment that is not in the list.
        sys.exit(1)
    key_offset = val_exp_set[args.exp_name]  # NOTE(review): not read in the visible code

    from utils.run_utils import setup_logger_kwargs
    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

    # Seed setting
    seed = args.seed
    seed += 10000 * proc_id()
    np.random.seed(seed)
    torch.manual_seed(seed)
    if device.type == 'cuda': torch.cuda.manual_seed(seed)

    # ETC setting
    env_fn = lambda: gym.make(args.env)
    ac_kwargs = dict(hidden_sizes=[args.hid]*args.l)
    render = args.render
    epochs = args.epochs
    local_steps_per_epoch = args.local_steps_per_epoch
    max_ep_len = args.len
    # NOTE(review): the four settings below are not read in the visible code.
    storage_intest_number_array = [5000000, 5000000]
    default_action_margin = 0.05
    hazard_check_margin = 0.935
    context = [storage_intest_number_array[0], storage_intest_number_array[1]]

    # Main setting
    exp_name_split = args.exp_name.split('_')
    val_info = exp_name_split[0]
    ckpt_info = exp_name_split[1]
    agent_info = exp_name_split[2]
    task_info = exp_name_split[3]
    if len(exp_name_split) == 5: agent_info_sub = exp_name_split[4]
    else: agent_info_sub = None
    ckpt_info_sub = args.checkpoint

    # Path setting. Only check for an existing directory here. The directory is
    # created right before the run, so a failed agent/checkpoint load does not
    # leave an empty directory behind that would block every retry.
    val_data_path = args.exp_name + "/validation_" + ckpt_info_sub + "/"
    if os.path.exists(val_data_path):
        print("    DIRECTORY ALREADY EXISTS %s" %(val_data_path))
        sys.exit()
    train_default_path = "../PPO-POINT-TRAIN/" + agent_info + "_" + task_info + "_"
    if agent_info_sub is not None: train_default_path += (agent_info_sub + "/")

    # Agent setting
    if agent_info != "ppo":
        sys.exit()
    actor_critic = core.MLPActorCritic_ppo_point_train
    # Build a single env just to read the spaces; the original built two.
    probe_env = env_fn()
    agent = actor_critic(probe_env.observation_space, probe_env.action_space, **ac_kwargs)
    probe_env.close()
    agent.eval()

    # Validation setting
    if val_info == "pofval": sys.exit()

    # Checkpoint setting
    if ckpt_info != "pretraining":
        sys.exit()
    ckpt_path = train_default_path + "checkpoint/" + agent_info + "_" + task_info + "_"
    if agent_info_sub is not None: ckpt_path += (agent_info_sub + "_")
    ckpt_path += (ckpt_info_sub + ".pt")
    # Load separately from state-dict extraction. A missing or unreadable file
    # used to fall into the fallback branch with `backup` unbound.
    # The agent is never moved to `device`, so map tensors to CPU; this also lets
    # GPU-saved checkpoints load on CPU-only hosts.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    try:
        backup = torch.load(ckpt_path, map_location="cpu")
    except Exception as e:
        print(e)
        sys.exit()
    try: # (1)agent, (2, conditional)intest
        backup_dict = backup.state_dict()
        selected_net = ['pi', 'v', 'vc']
        selected_dict = {k: v for k, v in backup_dict.items() if any(net_name in k for net_name in selected_net)}
        agent.load_state_dict(selected_dict, strict=False)
        agent.eval()
        print("    LOAD AGENT %s" %(ckpt_path))
    except Exception as e:
        print(e)
        # Fallback for checkpoints stored as a dict with an "ac_ppo" entry (barrier runs).
        if val_info == "baseval" and agent_info_sub == "barrier":
            try: # (1)agent
                agent.load_state_dict(backup["ac_ppo"])
                agent.eval()
                print("    LOAD AGENT %s" %(ckpt_path))
            except Exception as e:
                print(e)
                sys.exit()
        else: sys.exit()

    # Run
    if val_info == "baseval":
        os.makedirs(val_data_path)
        print("    CREATE DIRECTORY %s" %(val_data_path))
        validation(0, agent, epochs, local_steps_per_epoch, render, max_ep_len, val_data_path)
    else: sys.exit()